from flask import Flask,send_file,request,jsonify
import os
import json
import subprocess
import pymysql
pymysql.install_as_MySQLdb()
from flask_sqlalchemy import SQLAlchemy
import datetime
import math
from flask_executor import Executor
from flask_cors import CORS
import time

app = Flask(__name__)
# Keep non-ASCII (Chinese) text intact in JSON responses instead of \uXXXX escapes.
app.config['JSON_AS_ASCII'] = False

# NOTE(review): the database URI is blank here — presumably filled in at
# deploy time; verify before running.
app.config['SQLALCHEMY_DATABASE_URI']=''

# db: Flask-SQLAlchemy session/models; executor: background runner used by /evaluate.
db = SQLAlchemy(app)
executor = Executor(app)
# Allow cross-origin requests with credentials (frontend served elsewhere).
CORS(app, supports_credentials=True)

class Evaluate_task(db.Model):
    """ORM row for one model-evaluation run."""
    __tablename__ = 'evaluate_task'
    id = db.Column(db.Integer, primary_key=True,autoincrement=True)
    model_id = db.Column(db.Integer)       # numeric model id (see get_model_name for the mapping)
    start_time = db.Column(db.String(64))  # creation time string "YYYY-MM-DD HH:MM:SS" (get_now_time_str)
    dataset_id = db.Column(db.Integer)     # references datasets.id (see get_dataset_name)
    status = db.Column(db.Integer)         # 1 = running (set at creation), 0 = finished (set by /evaluate worker)

    def __repr__(self):
        return 'Evaluate_task(%r)' % self.id

class Compare_task(db.Model):
    """ORM row linking two evaluate tasks whose results are being compared."""
    __tablename__ = 'compare_task'
    id = db.Column(db.Integer, primary_key=True, autoincrement=True)
    base_model_id = db.Column(db.Integer)     # model_id of the base evaluate task
    compare_model_id = db.Column(db.Integer)  # model_id of the compared evaluate task
    task_id_0 = db.Column(db.Integer)         # evaluate_task.id of the base task
    task_id_1 = db.Column(db.Integer)         # evaluate_task.id of the compared task
    start_time = db.Column(db.String(64))     # creation time string "YYYY-MM-DD HH:MM:SS"

    def __repr__(self):
        return 'Compare_task(%r)' % self.id

class Datasets(db.Model):
    """ORM row describing one available dataset."""
    __tablename__ = 'datasets'
    id = db.Column(db.Integer, primary_key=True, autoincrement=True)
    dataset_name = db.Column(db.String(64))  # human-readable name shown in task listings
    create_time = db.Column(db.String(64))   # creation time string

    def __repr__(self):
        return 'dataset(%r)' % self.id

# @app.route('/reports/<id>')
# def report(id):
#     file_name = id.replace("/", "_")
#     file_name = file_name + ".html"
#     file_path = "./templates/" + file_name
#     if os.path.exists(file_path):
#         return send_file('file_name')
#     else:
#         return "评测任务运行中，请运行结束后查看。"

def get_now_time_str():
    """Return the current Beijing (UTC+8) time as "YYYY-MM-DD HH:MM:SS".

    The original implementation added 8 hours to the server's *local* time,
    which is only correct when the server clock runs in UTC. Deriving the
    time from an explicit UTC+8 timezone is correct regardless of the
    server's timezone configuration.
    """
    beijing_tz = datetime.timezone(datetime.timedelta(hours=8))
    return datetime.datetime.now(tz=beijing_tz).strftime("%Y-%m-%d %H:%M:%S")

@app.route('/compare', methods=["POST", "GET"])
def compare():
    """Create a comparison task between two existing evaluate tasks.

    Accepts ``base_task`` / ``compare_task`` evaluate-task ids either as a
    JSON body (POST) or query parameters (GET), stores a new Compare_task
    row and returns ``{"result_code": 0}``.
    """
    if request.method == "POST":
        data = json.loads(request.data)
        task_id_0 = data["base_task"]
        task_id_1 = data["compare_task"]
    else:
        task_id_0 = request.args.get("base_task")
        task_id_1 = request.args.get("compare_task")

    time_now = get_now_time_str()
    # Fetch each evaluate task once instead of re-running the same query
    # per attribute as the original did.
    base_task = Evaluate_task.query.filter(Evaluate_task.id == task_id_0).first()
    compared_task = Evaluate_task.query.filter(Evaluate_task.id == task_id_1).first()

    c_t = Compare_task(base_model_id=base_task.model_id,
                       compare_model_id=compared_task.model_id,
                       task_id_0=task_id_0,
                       task_id_1=task_id_1,
                       start_time=time_now)
    db.session.add(c_t)
    db.session.commit()
    return jsonify({"result_code": 0})

def _parse_confusion_counts(result_string):
    """Parse "TN <n> FN <n> TP <n> FP <n>" pairs from a result string.

    Each count starts at 1 (a crude Laplace-style smoothing kept from the
    original code) so the ratio metrics can never divide by zero.

    Returns (TN, FN, TP, FP).
    """
    counts = {"TN": 1, "FN": 1, "TP": 1, "FP": 1}
    tokens = result_string.split(' ')
    for i, token in enumerate(tokens):
        # A label token is followed by its integer count; guard the lookahead
        # so a trailing label in malformed input cannot raise IndexError.
        if token in counts and i + 1 < len(tokens):
            counts[token] += int(tokens[i + 1])
    return counts["TN"], counts["FN"], counts["TP"], counts["FP"]


def _classification_metrics(tn, fn, tp, fp):
    """Compute (E, ACC, Pr, Re, F1, Gmean, TPR, FPR) from confusion counts."""
    total = tp + fn + fp + tn
    e = (fn + fp) / total
    acc = (tn + tp) / total
    pr = tp / (tp + fp)
    re = tp / (tp + fn)
    f1 = (2 * pr * re) / (pr + re)
    gmean = math.sqrt((tp / (tp + fn)) * (tn / (tn + fp)))
    tpr = tp / (tp + fn)
    fpr = fp / (fp + tn)
    return e, acc, pr, re, f1, gmean, tpr, fpr


def get_compare_metrics():
    """Read both comparison result files and compute the metric sets.

    Reads /datas/tmp/test_result_1.txt (base model, "B_" values) and
    /datas/tmp/test_result_2.txt (compared model, "C_" values) — written by
    /server/compare_read.sh — and returns the 25-tuple
    (B_E, C_E, Diff_E, ..., B_FPR, C_FPR, Diff_FPR, analysis)
    consumed by the /compare_details route.
    """
    with open("/datas/tmp/test_result_1.txt", "r") as f:
        result_string_1 = f.read()
    with open("/datas/tmp/test_result_2.txt", "r") as f:
        result_string_2 = f.read()

    B_TN, B_FN, B_TP, B_FP = _parse_confusion_counts(result_string_1)
    C_TN, C_FN, C_TP, C_FP = _parse_confusion_counts(result_string_2)

    (B_E, B_ACC, B_Pr, B_Re, B_F1_Score,
     B_Gmean, B_TPR, B_FPR) = _classification_metrics(B_TN, B_FN, B_TP, B_FP)
    (C_E, C_ACC, C_Pr, C_Re, C_F1_Score,
     C_Gmean, C_TPR, C_FPR) = _classification_metrics(C_TN, C_FN, C_TP, C_FP)

    Diff_E = B_E - C_E
    Diff_ACC = B_ACC - C_ACC
    Diff_Pr = B_Pr - C_Pr
    Diff_Re = B_Re - C_Re
    Diff_F1_Score = B_F1_Score - C_F1_Score
    Diff_Gmean = B_Gmean - C_Gmean
    Diff_TPR = B_TPR - C_TPR
    Diff_FPR = B_FPR - C_FPR

    ana1 = ['被比较模型的错误率(E)、分类准确率(ACC)、精度(Precision)、查全率(Recall)等指标综合表现更好。 ',
        '被比较模型的错误率(E)、分类准确率(ACC)、精度(Precision)、查全率(Recall)等指标综合表现不如比较模型，建议用户通过标准化/归一化、正则化、添加dropout层来尝试优化。 ']
    ana2 = ['被比较模型对正常日志判断准确率高，对异常日志判断准确率低，可能存在 数据不平衡的问题，建议用户通过过采样、欠采样、样本加权等方式进行优化。',
        '被比较模型对正常日志和异常日志判断准确率都较高，不存在数据不平衡的问题 。']

    # Overall verdict: a positive aggregate difference means the base model wins.
    analysis1 = ana1[0] if Diff_ACC + Diff_Pr + Diff_Re - Diff_E > 0 else ana1[1]
    # Data-imbalance heuristic based on the base model's ROC position.
    analysis2 = ana2[0] if B_FPR < 0.2 and B_TPR < 0.6 else ana2[1]
    analysis = analysis1 + analysis2

    return (B_E, C_E, Diff_E, B_ACC, C_ACC, Diff_ACC, B_Pr, C_Pr, Diff_Pr,
            B_Re, C_Re, Diff_Re, B_F1_Score, C_F1_Score, Diff_F1_Score,
            B_Gmean, C_Gmean, Diff_Gmean, B_TPR, C_TPR, Diff_TPR,
            B_FPR, C_FPR, Diff_FPR, analysis)


def get_model_name(model_id):
    """Map a numeric model id to its display name.

    Unknown ids fall back to "SVM" (note: distinct from the "Svm" used for
    id 3 — this mirrors the original behavior).
    """
    names = {
        1: "Lr",
        2: "DecisionTree",
        3: "Svm",
        4: "DeepLog",
        5: "LogAnomaly",
    }
    return names.get(model_id, "SVM")

def get_dataset_name(dataset_id):
    """Return the dataset_name column for the given datasets.id.

    Raises IndexError if no row matches (same behavior as the original).
    """
    matches = Datasets.query.filter(Datasets.id == dataset_id).all()
    return matches[0].dataset_name

@app.route('/compare_details', methods=["POST", "GET"])
def compare_details():
    """Return the detailed metric comparison report for one compare task.

    Reads ``diff_task_id`` from the JSON body (POST) or query string (GET),
    runs /server/compare_read.sh on the two tasks' result files and returns
    the per-metric scores plus a textual conclusion.
    """
    if request.method == "POST":
        data = json.loads(request.data)
        diff_task_id = data["diff_task_id"]
    else:
        diff_task_id = request.args.get("diff_task_id")

    # One query per table instead of re-running the same filter per column.
    diff_task = Compare_task.query.filter(Compare_task.id == diff_task_id).first()
    time_now = diff_task.start_time
    base_task_id = diff_task.task_id_0
    compare_task_id = diff_task.task_id_1
    base_model_id = diff_task.base_model_id
    compare_model_id = diff_task.compare_model_id

    base_dataset_id = Evaluate_task.query.filter(Evaluate_task.id == base_task_id).first().dataset_id
    compare_dataset_id = Evaluate_task.query.filter(Evaluate_task.id == compare_task_id).first().dataset_id

    base_model_name = get_model_name(base_model_id)
    compare_model_name = get_model_name(compare_model_id)

    result_path_1 = "/server_output/" + str(base_task_id) + "/part-r-00000"
    result_path_2 = "/server_output/" + str(compare_task_id) + "/part-r-00000"

    # Argument list (shell=False) instead of a shell string so the
    # request-supplied id cannot be used for shell injection.
    subprocess.run(["bash", "/server/compare_read.sh", result_path_1, result_path_2])

    (B_E, C_E, Diff_E, B_ACC, C_ACC, Diff_ACC, B_Pr, C_Pr, Diff_Pr,
     B_Re, C_Re, Diff_Re, B_F1_Score, C_F1_Score, Diff_F1_Score,
     B_Gmean, C_Gmean, Diff_Gmean, B_TPR, C_TPR, Diff_TPR,
     B_FPR, C_FPR, Diff_FPR, analysis) = get_compare_metrics()

    # The original dict literal repeated base_dataset_id / compare_task_id /
    # compare_model_id; duplicates removed (identical values, so the JSON
    # payload is unchanged).
    data = {
        "diff_task_id": diff_task_id,
        "base_task_id": base_task_id,
        "base_model_id": base_model_id,
        "base_dataset_id": base_dataset_id,
        "compare_task_id": compare_task_id,
        "compare_model_id": compare_model_id,
        "compare_dataset_id": compare_dataset_id,
        "base_model_name": base_model_name,
        "compare_model_name": compare_model_name,
        "base_dataset_name": get_dataset_name(base_dataset_id),
        "compare_dataset_name": get_dataset_name(compare_dataset_id),
        "start_time": time_now,
        "scores": [
            {"metrics": "错误率(Error Rate)",
             "base_score": B_E,
             "compare_score": C_E,
             "diff_score": Diff_E},
            {"metrics": "分类准确率(ACC)",
             "base_score": B_ACC,
             "compare_score": C_ACC,
             "diff_score": Diff_ACC},
            {"metrics": "精度(Precision)",
             "base_score": B_Pr,
             "compare_score": C_Pr,
             "diff_score": Diff_Pr},
            {"metrics": "查全率(Recall)",
             "base_score": B_Re,
             "compare_score": C_Re,
             "diff_score": Diff_Re},
            {"metrics": "F1_Score",
             "base_score": B_F1_Score,
             "compare_score": C_F1_Score,
             "diff_score": Diff_F1_Score},
            {"metrics": "几何平均值(gmean)",
             "base_score": B_Gmean,
             "compare_score": C_Gmean,
             "diff_score": Diff_Gmean},
            {"metrics": "真正例率(TPR)",
             "base_score": B_TPR,
             "compare_score": C_TPR,
             "diff_score": Diff_TPR},
            {"metrics": "反正例率(FPR)",
             "base_score": B_FPR,
             "compare_score": C_FPR,
             "diff_score": Diff_FPR},
        ],
        "conclude": analysis
    }
    return jsonify(data)


@app.route('/compare_history', methods=['POST','GET'])
def compare_history():
    """Paginated list of compare tasks, newest first.

    ``offset`` is the page number and ``limit`` the page size
    (row offset = offset * limit).
    """
    if request.method == "POST":
        data = json.loads(request.data)
        num = data["offset"]
        per = data["limit"]
    else:
        num = request.args.get("offset")
        per = request.args.get("limit")
    # Query-string values arrive as strings; the original computed
    # str * str below, raising TypeError on every GET request.
    num = int(num)
    per = int(per)

    compare_tasks = (Compare_task.query.order_by(Compare_task.id.desc())
                     .offset(num * per).limit(per))

    compare_tasks_list = []
    for task in compare_tasks:
        compare_tasks_list.append({
            "task_id": task.id,
            "base_model_id": task.base_model_id,
            "base_model_name": get_model_name(task.base_model_id),
            "compare_model_id": task.compare_model_id,
            "compare_model_name": get_model_name(task.compare_model_id),
            "start_time": task.start_time
        })

    return jsonify({"compare_task": compare_tasks_list})

def get_evaluate_metrics():
    """Read /datas/tmp/test_result.txt and compute the evaluation metrics.

    The file (written by /server/evaluate_read.sh) contains space-separated
    "TN <n> FN <n> TP <n> FP <n>" tokens. Each count is incremented by 1
    (kept from the original code) so no metric below divides by zero.

    Returns (E, ACC, Pr, Re, F1_Score, Gmean, TPR, FPR).
    """
    with open("/datas/tmp/test_result.txt", "r") as f:
        result_string = f.read()

    counts = {"TN": 1, "FN": 1, "TP": 1, "FP": 1}
    tokens = result_string.split(' ')
    for i, token in enumerate(tokens):
        # A label token is followed by its integer count; guard the lookahead
        # so a trailing label in malformed input cannot raise IndexError.
        if token in counts and i + 1 < len(tokens):
            counts[token] += int(tokens[i + 1])
    TN, FN, TP, FP = counts["TN"], counts["FN"], counts["TP"], counts["FP"]

    total = TP + FN + FP + TN
    E = (FN + FP) / total
    ACC = (TN + TP) / total
    Pr = TP / (TP + FP)
    Re = TP / (TP + FN)
    F1_Score = (2 * Pr * Re) / (Pr + Re)
    Gmean = math.sqrt((TP / (TP + FN)) * (TN / (TN + FP)))
    TPR = TP / (TP + FN)
    FPR = FP / (FP + TN)
    return E, ACC, Pr, Re, F1_Score, Gmean, TPR, FPR

@app.route('/evaluate', methods=["POST", "GET"])
def evaluate():
    """Create evaluate-task rows and run the evaluations in the background.

    POST body: {"modelIDs": [int, ...], "dataSet": int}.
    GET: ``model_id`` / ``dataset_id`` query parameters (single model).
    Returns ``{"result_code": 0}`` immediately; the shell tool runs via the
    Flask-Executor thread pool and flips each task's status to 0 when done.
    """
    modelIDs = []
    dataSet = 1

    if request.method == "POST":
        data = json.loads(request.data)
        modelIDs = data["modelIDs"]
        dataSet = data["dataSet"]
    else:
        # BUG FIX: the original read these into unused locals, so a GET
        # request always evaluated zero models on dataset 1.
        model_id = request.args.get("model_id")
        dataset_id = request.args.get("dataset_id")
        if model_id is not None:
            modelIDs = [int(model_id)]
        if dataset_id is not None:
            dataSet = int(dataset_id)

    time_now = get_now_time_str()
    dataset_names = {1: "dataset_1", 2: "dataset_2", 3: "dataset_3"}
    dataset_id_str = dataset_names.get(dataSet, "dataset_test")

    # Create one evaluate_task row (status=1 == running) per requested model.
    task_model_ids_list = []
    for one_model_id in modelIDs:
        e_t = Evaluate_task(model_id=one_model_id, start_time=time_now,
                            dataset_id=dataSet, status=1)
        db.session.add(e_t)
        db.session.commit()
        task_model_ids_list.append([e_t.id, one_model_id])

    def execute_evaluation(task_model_ids_list, dataset_id_str):
        # Background worker: runs outside the request context, so it opens
        # its own raw connection (the original shadowed the module-level
        # `db` with this connection and never closed it).
        conn = pymysql.connect(host='',
                               port=31111,
                               user='',
                               password='',
                               database='')
        try:
            # Only ids 1-3 have shell-tool support; everything else runs Svm.
            model_names = {1: "Lr", 2: "DecisionTree", 3: "Svm"}
            for task_id, model_id in task_model_ids_list:
                model_id_str = model_names.get(model_id, "Svm")
                task_id_str = "/server_output/" + str(task_id)

                command = ("bash /datas/start_tool.sh -m " + model_id_str +
                           " -d " + dataset_id_str + " -o " + task_id_str)
                print(command)
                os.system(command)

                # Mark the task finished; parameterized query instead of
                # string concatenation (avoids SQL injection / quoting bugs).
                cursor = conn.cursor()
                cursor.execute("UPDATE evaluate_task SET status=0 WHERE id = %s",
                               (task_id,))
                conn.commit()
                cursor.close()
        finally:
            conn.close()

    executor.submit(execute_evaluation, task_model_ids_list, dataset_id_str)
    return jsonify({"result_code": 0})

@app.route('/evaluate_details', methods=["POST", "GET"])
def evaluate_details():
    """Return the metric report for one finished evaluate task.

    Reads ``report_task_id`` from the JSON body (POST) or query string
    (GET), runs /server/evaluate_read.sh on the task's result file and
    returns the computed scores.
    """
    if request.method == "POST":
        data = json.loads(request.data)
        report_task_id = data["report_task_id"]
    else:
        report_task_id = request.args.get("report_task_id")

    # Fetch the row once instead of running three identical queries.
    task = Evaluate_task.query.filter(Evaluate_task.id == report_task_id).first()
    start_time = task.start_time
    model_id = task.model_id
    dataset_id = task.dataset_id

    result_path = "/server_output/" + str(report_task_id) + "/part-r-00000"

    # Argument list (shell=False) so the request-supplied id cannot be used
    # for shell injection.
    subprocess.run(["bash", "/server/evaluate_read.sh", result_path])

    E, ACC, Pr, Re, F1_Score, Gmean, TPR, FPR = get_evaluate_metrics()
    data = {
        "report_task_id": report_task_id,
        "model_id": model_id,
        "start_time": start_time,
        "dataset_id": dataset_id,
        "model_name": get_model_name(model_id),
        "dataset_name": get_dataset_name(dataset_id),
        "scores": [
            {"metrics": "错误率(Error Rate)", "score": E},
            {"metrics": "分类准确率(ACC)", "score": ACC},
            {"metrics": "精度(Precision)", "score": Pr},
            {"metrics": "查全率(Recall)", "score": Re},
            {"metrics": "F1_Score", "score": F1_Score},
            {"metrics": "几何平均值(gmean)", "score": Gmean},
            {"metrics": "真正例率(TPR)", "score": TPR},
            {"metrics": "反正例率(FPR)", "score": FPR}
        ]
    }
    return jsonify(data)

@app.route('/evaluate_history', methods=["POST", "GET"])
def evaluate_history():
    """Paginated list of evaluate tasks, newest first.

    ``offset`` is the page number and ``limit`` the page size
    (row offset = offset * limit).
    """
    if request.method == "POST":
        data = json.loads(request.data)
        num = data["offset"]
        per = data["limit"]
    else:
        num = request.args.get("offset")
        per = request.args.get("limit")
    # Query-string values arrive as strings; the original computed
    # str * str below, raising TypeError on every GET request.
    num = int(num)
    per = int(per)

    evaluate_tasks = (Evaluate_task.query.order_by(Evaluate_task.id.desc())
                      .offset(num * per).limit(per))
    evaluate_tasks_list = []
    for task in evaluate_tasks:
        evaluate_tasks_list.append({
            "task_id": task.id,
            "model_id": task.model_id,
            "dataset_id": task.dataset_id,
            "model_name": get_model_name(task.model_id),
            "dataset_name": get_dataset_name(task.dataset_id),
            "start_time": task.start_time,
            "status": task.status
        })

    return jsonify({"evaluate_task": evaluate_tasks_list})



@app.route('/all_datasets', methods=["POST", "GET"])
def all_datasets():
    """List every dataset row as {dataset_id, dataset_name, create_time}."""
    datasets_list = [
        {
            "dataset_id": one_set.id,
            "dataset_name": one_set.dataset_name,
            "create_time": one_set.create_time,
        }
        for one_set in Datasets.query
    ]
    return jsonify({"datasets": datasets_list})

@app.route('/num_evaluate_task', methods=["POST", "GET"])
def num_evaluate_task():
    """Return the total number of evaluate tasks as {"num": count}."""
    return jsonify({"num": Evaluate_task.query.count()})


@app.route('/num_compare_task', methods=["POST", "GET"])
def num_compare_task():
    """Return the total number of compare tasks as {"num": count}."""
    return jsonify({"num": Compare_task.query.count()})


if __name__ == "__main__":
    # Guard the dev server behind __main__ so importing this module
    # (e.g. under a WSGI server) no longer starts a second server.
    app.run(host='0.0.0.0', port=8888)